import glob
import os
import shutil
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow.contrib as tcon
from IPython import display

# Load MNIST; train_images is a NumPy array of shape (60000, 28, 28), dtype uint8.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# Reshape to NHWC [batch, height, width, channel] and convert to float32.
# NOTE: this must happen on the NumPy array -- the original called .astype()
# on a tf.Tensor, which has no such method.
train_images_reshape = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')

# Scale pixels to [-1, 1] to match the generator's tanh output range.
# (The original normalized the raw, unreshaped uint8 array instead.)
train_images_nor = (train_images_reshape - 127.5) / 127.5


BUFFER_SIZE = 60000
BATCH_SIZE = 256

# Build the input pipeline (read, transform, load) from the *normalized*
# images -- the original fed the raw uint8 data to the dataset.
train_dataset = tf.data.Dataset.from_tensor_slices(train_images_nor).shuffle(buffer_size=BUFFER_SIZE).batch(BATCH_SIZE)


def make_generator_model():
    """Build the DCGAN generator: a (100,) noise vector -> 28x28x1 image in [-1, 1].

    Built back-to-front around transposed convolutions that upsample
    7x7 -> 14x14 -> 28x28.
    """
    model = tf.keras.Sequential()

    # Project the noise vector up to a 7*7*256 feature volume.
    # use_bias=False because the following BatchNorm cancels any bias term
    # (a bias only shifts the distribution, which BN re-centers anyway).
    # NOTE: units is the first positional argument -- the original passed
    # input_dim=7*7*256 and omitted units entirely, which raises a TypeError.
    model.add(
        tf.keras.layers.Dense(
            7 * 7 * 256,
            use_bias=False,
            # Noise dimension.
            input_shape=(100,)
        )
    )
    # As training proceeds, layer input distributions drift toward the
    # saturated ends of activations, slowing or killing gradients.
    # BatchNorm re-centers pre-activations (mean 0, std 1), keeping the
    # activation in its sensitive region; it also acts as a mild regularizer
    # and tolerates larger learning rates / rougher initialization.
    model.add(tf.keras.layers.BatchNormalization())
    # LeakyReLU keeps a small gradient for negative inputs, so units can
    # keep learning where plain ReLU would output a flat 0.
    model.add(tf.keras.layers.LeakyReLU())

    # Reshape takes target_shape as its first argument -- the original passed
    # input_shape only, which configures nothing to reshape *to*.
    model.add(tf.keras.layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)

    # use_bias must be the boolean False: the original passed the string
    # 'False', which is truthy, silently enabling biases in all three layers.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=128,
        kernel_size=5,
        strides=1,
        padding='same',
        use_bias=False
    ))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # Odd kernel size lets 'same' padding stay symmetric on both edges: 00xxxx00.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=64,
        kernel_size=5,
        strides=2,
        padding='same',
        use_bias=False
    ))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    # tanh outputs in [-1, 1], zero-centered to match the input normalization;
    # a sigmoid here would produce all-positive gradients in backprop and
    # zig-zag weight updates.
    model.add(tf.keras.layers.Conv2DTranspose(
        filters=1,
        kernel_size=5,
        strides=2,
        padding='same',
        use_bias=False,
        activation='tanh'
    ))
    assert model.output_shape == (None, 28, 28, 1)

    return model


def make_discriminator_model():
    """Build the DCGAN discriminator: 28x28x1 image -> one unbounded real/fake logit."""
    layers = tf.keras.layers
    model = tf.keras.Sequential()

    # Strided convolutions downsample in place of pooling.
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    # Dropout is usually seen on dense layers, but it helps on conv layers
    # too: even with few parameters, dropping activations here feeds noisy
    # input to the dense layer below, which curbs overfitting.
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Flatten and emit a single logit; the sigmoid lives in the loss.
    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# Instantiate both networks once at module level; the training loop below
# closes over these globals.
generator = make_generator_model()
discriminator = make_discriminator_model()

def generator_loss(generator_output):
    """Generator loss: push the discriminator's logits on fakes toward 'real' (1)."""
    all_real_labels = tf.ones_like(generator_output)
    return tf.losses.sigmoid_cross_entropy(
        multi_class_labels=all_real_labels,
        logits=generator_output,
    )


def discriminator_loss(real_output, generated_output):
    """Discriminator loss: real logits labeled 1, generated logits labeled 0."""
    # Real batch against all-ones labels: [1, 1, ..., 1].
    loss_on_real = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(real_output),
        logits=real_output,
    )
    # Generated batch against all-zeros labels: [0, 0, ..., 0].
    loss_on_fake = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.zeros_like(generated_output),
        logits=generated_output,
    )
    # Total loss is the sum of both terms.
    return loss_on_real + loss_on_fake

# The two models train simultaneously, so each gets its own optimizer
# (Adam, learning rate 1e-4; tf.train.AdamOptimizer is the TF1-era API).
generator_optimizer = tf.train.AdamOptimizer(1e-4)
discriminator_optimizer = tf.train.AdamOptimizer(1e-4)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Checkpoint both models and both optimizer states together.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
# Number of passes over the dataset.
EPOCHS = 50
# Dimensionality of the generator's noise input.
noise_dim = 100

# Number of sample images per visualization grid; the noise vectors are
# fixed once so progress across epochs is visually comparable.
num_examples_to_generate = 16
random_vector_for_generation = tf.random_normal([num_examples_to_generate,
                                                 noise_dim])


def train_step(images):
    """Run one optimization step for both networks on one batch of real images.

    Args:
        images: batch of real images (assumed already scaled to [-1, 1] to
            match the generator's tanh range -- verify against the pipeline).
    """
    # Standard-normal noise as the generator's input.
    noise = tf.random_normal([BATCH_SIZE, noise_dim])

    # Record both forward passes so each tape can differentiate its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        # The discriminator scores both the real batch and the fakes.
        real_output = discriminator(images, training=True)
        generated_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(generated_output)
        disc_loss = discriminator_loss(real_output, generated_output)

    # Differentiate w.r.t. *trainable* variables only: BatchNorm's moving
    # mean/variance are non-trainable and receive no gradient, so including
    # them (as the original did via .variables) only yields None entries.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# defun traces train_step into a graph for speed (TF1-era eager API).
train_step = tf.contrib.eager.defun(train_step)


def train(dataset, epochs):
    """Training driver: run every batch each epoch, saving sample grids and checkpoints."""
    for epoch in range(epochs):
        epoch_start = time.time()

        # One optimization step per batch.
        for image_batch in dataset:
            train_step(image_batch)

        display.clear_output(wait=True)

        # Save a grid of generated samples for later visualization.
        generate_and_save_images(generator, epoch + 1, random_vector_for_generation)

        # Checkpoint every 15 epochs. (For deployment to TensorFlow Serving,
        # export a SavedModel instead.)
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        elapsed = time.time() - epoch_start
        print('Time taken for epoch {} is {} sec'.format(epoch + 1, elapsed))

    # Final snapshot after the last epoch.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, random_vector_for_generation)


def generate_and_save_images(model, epoch, test_input):
    """Generate a 4x4 grid of samples from test_input and save it as a PNG."""
    # training=False keeps BatchNorm in inference mode (no statistics update).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))
    for idx in range(predictions.shape[0]):
        plt.subplot(4, 4, idx + 1)
        # Undo the [-1, 1] scaling back to pixel values for display.
        plt.imshow(predictions[idx, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


train(train_dataset, EPOCHS)

def display_image(epoch_no):
    """Open the saved sample-grid PNG for the given epoch."""
    return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))


# Assemble the per-epoch PNGs into an animated GIF, sampling frames more
# sparsely as training progresses (square-root schedule).
with imageio.get_writer('dcgan.gif', mode='I') as writer:
    filenames = sorted(glob.glob('image*.png'))
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2 * (i ** 0.5)
        # Keep a frame only when the sqrt schedule advances a full step.
        if round(frame) <= round(last):
            continue
        last = frame
        writer.append_data(imageio.imread(filename))
    if filenames:
        # Repeat the final image so the GIF lingers on the last epoch.
        # (Guarded: the original raised NameError here when no PNGs existed.)
        writer.append_data(imageio.imread(filenames[-1]))

# Copy the GIF under a .png name so IPython displays it inline.
# shutil.copy is portable, unlike the original os.system('cp ...').
shutil.copy('dcgan.gif', 'dcgan.gif.png')
display.Image(filename="dcgan.gif.png")